# Volatility
# Copyright (C) 2007-2013 Volatility Foundation
# Copyright (c) 2010, 2011, 2012 Michael Ligh <michael.ligh@mnin.org>
#
# This file is part of Volatility.
#
# Volatility is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License Version 2 as
# published by the Free Software Foundation.  You may not use, modify or
# distribute this program under any other version of the GNU General
# Public License.
#
# Volatility is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Volatility.  If not, see <http://www.gnu.org/licenses/>.
#

import volatility.utils as utils
import volatility.obj as obj
import volatility.scan as scan
import volatility.debug as debug
import volatility.plugins.common as common
import volatility.win32.modules as modules
import volatility.win32.tasks as tasks
import volatility.plugins.malware.devicetree as devicetree

try:
    import distorm3
    has_distorm3 = True
except ImportError:
    has_distorm3 = False

#--------------------------------------------------------------------------------
# vtypes
#--------------------------------------------------------------------------------

# Structure layouts (vtypes) for the callback-related kernel structures.
# Each entry has the form
#     'StructName': [size, {'Member': [offset, [type, ...]]}]
# where a size of None leaves the structure size unspecified. The offsets
# are for 32-bit Windows.
callback_types = {
    # Read by the file system registration change scanner (pool tag IoFs)
    '_NOTIFICATION_PACKET' : [ 0x10, {
    'ListEntry' : [ 0x0, ['_LIST_ENTRY']],
    'DriverObject' : [ 0x8, ['pointer', ['_DRIVER_OBJECT']]],
    'NotificationRoutine' : [ 0xC, ['unsigned int']],
    } ],
    # Entries of the KeBugCheckCallbackListHead list
    '_KBUGCHECK_CALLBACK_RECORD' : [ 0x20, {
    'Entry' : [ 0x0, ['_LIST_ENTRY']],
    'CallbackRoutine' : [ 0x8, ['unsigned int']],
    'Buffer' : [ 0xC, ['pointer', ['void']]],
    'Length' : [ 0x10, ['unsigned int']],
    'Component' : [ 0x14, ['pointer', ['String', dict(length = 64)]]],
    'Checksum' : [ 0x18, ['pointer', ['unsigned int']]],
    'State' : [ 0x1C, ['unsigned char']],
    } ],
    # Entries of the KeBugCheckReasonCallbackListHead list
    '_KBUGCHECK_REASON_CALLBACK_RECORD' : [ 0x1C, {
    'Entry' : [ 0x0, ['_LIST_ENTRY']],
    'CallbackRoutine' : [ 0x8, ['unsigned int']],
    'Component' : [ 0xC, ['pointer', ['String', dict(length = 8)]]],
    'Checksum' : [ 0x10, ['pointer', ['unsigned int']]],
    'Reason' : [ 0x14, ['unsigned int']],
    'State' : [ 0x18, ['unsigned char']],
    } ],
    # Read by the shutdown notification scanner (pool tag IoSh)
    '_SHUTDOWN_PACKET' : [ 0xC, {
    'Entry' : [ 0x0, ['_LIST_ENTRY']],
    'DeviceObject' : [ 0x8, ['pointer', ['_DEVICE_OBJECT']]],
    } ],
    # Three 4-byte members, so the structure spans 0xC bytes; the size has
    # to cover the Context member that lives at offset 0x8.
    '_EX_CALLBACK_ROUTINE_BLOCK' : [ 0xC, {
    'RundownProtect' : [ 0x0, ['unsigned int']],
    'Function' : [ 0x4, ['unsigned int']],
    'Context' : [ 0x8, ['unsigned int']],
    } ],
    # Read by the generic callback scanner (pool tag Cbrb)
    '_GENERIC_CALLBACK' : [ 0xC, {
    'Callback' : [ 0x4, ['pointer', ['void']]],
    'Associated' : [ 0x8, ['pointer', ['void']]],
    } ],
    '_REGISTRY_CALLBACK_LEGACY' : [ 0x38, {
    'CreateTime' : [ 0x0, ['WinTimeStamp', dict(is_utc = True)]],
    } ],
    # Read by the registry callback scanner (pool tag CMcb)
    '_REGISTRY_CALLBACK' : [ None, {
    'ListEntry' : [ 0x0, ['_LIST_ENTRY']],
    'Function' : [ 0x1C, ['pointer', ['void']]],
    } ],
    # Read by the DebugPrint callback scanner (pool tag DbCb)
    '_DBGPRINT_CALLBACK' : [ 0x14, {
    'Function' : [ 0x8, ['pointer', ['void']]],
    } ],
    # Read by the plug and play notification scanners (Pnp9, PnpD, PnpC)
    '_NOTIFY_ENTRY_HEADER' : [ None, {
    'ListEntry' : [ 0x0, ['_LIST_ENTRY']],
    'EventCategory' : [ 0x8, ['Enumeration', dict(target = 'long', choices = {
            0: 'EventCategoryReserved',
            1: 'EventCategoryHardwareProfileChange',
            2: 'EventCategoryDeviceInterfaceChange',
            3: 'EventCategoryTargetDeviceChange'})]],
    'CallbackRoutine' : [ 0x14, ['unsigned int']],
    'DriverObject' : [ 0x1C, ['pointer', ['_DRIVER_OBJECT']]],
    } ],
}

#--------------------------------------------------------------------------------
# object classes
#--------------------------------------------------------------------------------

class _SHUTDOWN_PACKET(obj.CType):
    """Shutdown notification packet with a validity check for scan hits"""

    def sanity_check(self):
        """
        Decide whether this packet looks like a genuine registration.

        The list links and the device pointer must all be valid addresses
        in obj_native_vm (kernel space), and the object header located in
        front of the device object must report the type "Device".
        """

        kernel_vm = self.obj_native_vm

        # Both list links have to point into valid kernel memory
        links_valid = (kernel_vm.is_valid_address(self.Entry.Flink) and
                       kernel_vm.is_valid_address(self.Entry.Blink))
        if not links_valid:
            return False

        # ... and so does the device object pointer
        if not kernel_vm.is_valid_address(self.DeviceObject):
            return False

        device = self.DeviceObject.dereference()

        # The object header sits "Body" bytes before the object itself
        body_offset = kernel_vm.profile.get_obj_offset("_OBJECT_HEADER", "Body")
        header = obj.Object("_OBJECT_HEADER",
                offset = device.obj_offset - body_offset,
                vm = device.obj_vm,
                native_vm = device.obj_native_vm)

        return header.get_object_type() == "Device"

#--------------------------------------------------------------------------------
# profile modifications 
#--------------------------------------------------------------------------------

class MalwareCallbackMods(obj.ProfileModification):
    """Add the callback vtypes and object classes to a profile"""

    # Ordering hint: this modification is declared 'before' WindowsOverlay
    before = ['WindowsOverlay']

    # Predicates on the profile metadata: Windows with a 32-bit memory model
    conditions = {
        'os': lambda os_name: os_name == 'windows',
        'memory_model': lambda model: model == '32bit',
    }

    def modification(self, profile):
        """Merge callback_types and the _SHUTDOWN_PACKET class into profile"""
        profile.vtypes.update(callback_types)

        extra_classes = {'_SHUTDOWN_PACKET': _SHUTDOWN_PACKET}
        profile.object_classes.update(extra_classes)

#--------------------------------------------------------------------------------
# pool scanners
#--------------------------------------------------------------------------------

class AbstractCallbackScanner(scan.PoolScanner):
    """Base scanner returning the callback's offset (no object headers)"""

    def object_offset(self, found, address_space):
        """
        Convert a scanner hit into the offset of the callback structure.

        The result is found plus the distance from the PoolTag member to
        the end of the _POOL_HEADER, both taken from the address space's
        profile. (found is presumably the offset of the matched pool tag.)
        """
        profile = address_space.profile
        header_size = profile.get_obj_size("_POOL_HEADER")
        tag_offset = profile.get_obj_offset("_POOL_HEADER", "PoolTag")
        return found + (header_size - tag_offset)

class PoolScanFSCallback(AbstractCallbackScanner):
    """Pool scanner for File System callbacks (pool tag IoFs)"""
    checks = [
        ('PoolTagCheck', {'tag': "IoFs"}),
        ('CheckPoolSize', {'condition': lambda size: size == 0x18}),
        ('CheckPoolType', {'non_paged': True, 'paged': True, 'free': True}),
        # Disabled constraint: ('CheckPoolIndex', dict(value = 4))
    ]

class PoolScanShutdownCallback(AbstractCallbackScanner):
    """Pool scanner for Shutdown callbacks (pool tag IoSh)"""
    checks = [
        ('PoolTagCheck', {'tag': "IoSh"}),
        ('CheckPoolSize', {'condition': lambda size: size == 0x18}),
        ('CheckPoolType', {'non_paged': True, 'paged': True, 'free': True}),
        ('CheckPoolIndex', {'value': 0}),
    ]

class PoolScanGenericCallback(AbstractCallbackScanner):
    """Pool scanner for Generic callbacks (pool tag Cbrb)"""
    checks = [
        ('PoolTagCheck', {'tag': "Cbrb"}),
        ('CheckPoolSize', {'condition': lambda size: size == 0x18}),
        ('CheckPoolType', {'non_paged': True, 'paged': True, 'free': True}),
        # Disabled constraint: ('CheckPoolIndex', dict(value = 1)).
        # It is a good constraint for all images except Frank's
        # rustock-c.vmem
    ]

class PoolScanDbgPrintCallback(AbstractCallbackScanner):
    """Pool scanner for DebugPrint callbacks on Vista and 7 (pool tag DbCb)"""
    checks = [
        ('PoolTagCheck', {'tag': "DbCb"}),
        ('CheckPoolSize', {'condition': lambda size: size == 0x20}),
        ('CheckPoolType', {'non_paged': True, 'paged': True, 'free': True}),
        # Disabled constraint: ('CheckPoolIndex', dict(value = 0))
    ]

class PoolScanRegistryCallback(AbstractCallbackScanner):
    """PoolScanner for Registry Callbacks on Vista and 7"""
    checks = [ ('PoolTagCheck', dict(tag = "CMcb")),
               # Seen as 0x38 on Vista SP2 and 0x30 on 7 SP0 
               ('CheckPoolSize', dict(condition = lambda x: x >= 0x38)),
               ('CheckPoolType', dict(non_paged = True, paged = True, free = True)),
               ('CheckPoolIndex', dict(value = 4)),
               ]

class PoolScanPnp9(AbstractCallbackScanner):
    """Pool scanner for Pnp9 (EventCategoryHardwareProfileChange)"""
    checks = [
        ('PoolTagCheck', {'tag': "Pnp9"}),
        # seen as 0x2C on W7, 0x28 on vistasp0 (4 less but needs 8 less)
        ('CheckPoolSize', {'condition': lambda size: size >= 0x30}),
        ('CheckPoolType', {'non_paged': True, 'paged': True, 'free': True}),
        ('CheckPoolIndex', {'value': 1}),
    ]

class PoolScanPnpD(AbstractCallbackScanner):
    """Pool scanner for PnpD (EventCategoryDeviceInterfaceChange)"""
    checks = [
        ('PoolTagCheck', {'tag': "PnpD"}),
        # seen as 0x3C on W7, 0x38 on vistasp0 (4 less but needs 8 less)
        ('CheckPoolSize', {'condition': lambda size: size >= 0x40}),
        ('CheckPoolType', {'non_paged': True, 'paged': True, 'free': True}),
        ('CheckPoolIndex', {'value': 1}),
    ]

class PoolScanPnpC(AbstractCallbackScanner):
    """Pool scanner for PnpC (EventCategoryTargetDeviceChange)"""
    checks = [
        ('PoolTagCheck', {'tag': "PnpC"}),
        # seen as 0x34 on W7, 0x30 on vistasp0 (4 less but needs 8 less)
        ('CheckPoolSize', {'condition': lambda size: size >= 0x38}),
        ('CheckPoolType', {'non_paged': True, 'paged': True, 'free': True}),
        ('CheckPoolIndex', {'value': 1}),
    ]

#--------------------------------------------------------------------------------
# callbacks plugin
#--------------------------------------------------------------------------------

class Callbacks(common.AbstractWindowsCommand):
    "Print system-wide notification routines"

    @staticmethod
    def is_valid_profile(profile):
        """Return True only for Windows profiles with a 32-bit memory model"""
        return (profile.metadata.get('os', 'unknown') == 'windows' and
                profile.metadata.get('memory_model', '32bit') == '32bit')

    def __init__(self, *args, **kwargs):
        common.AbstractWindowsCommand.__init__(self, *args, **kwargs)
        # Both spaces are loaded by calculate() and shared by every
        # scanner as well as by render_text()
        self.phys_space = None
        self.kern_space = None

    @staticmethod
    def _find_operand_pointer(nt_mod, symbol, hexbytes, length = 100):
        """
        Locate the operand that follows a byte pattern in an NT export.

        Looks up the exported symbol in nt_mod, reads length bytes of its 
        code and searches them for hexbytes. Returns a Pointer object placed 
        directly after the pattern, or None when the export or the pattern 
        cannot be found.
        """

        # NOTE(review): kept as '== None' on purpose; volatility's
        # NoneObject presumably compares equal to None, which an
        # 'is None' test would miss -- confirm before changing.
        symbol_rva = nt_mod.getprocaddress(symbol)
        if symbol_rva == None:
            return None

        # Absolute virtual address of the exported function
        symbol_address = symbol_rva + nt_mod.DllBase

        data = nt_mod.obj_vm.zread(symbol_address, length)

        offset = data.find(hexbytes)
        if offset == -1:
            return None

        return obj.Object('Pointer',
                offset = symbol_address + offset + len(hexbytes),
                vm = nt_mod.obj_vm)

    @staticmethod
    def get_kernel_callbacks(nt_mod):
        """
        Enumerate the Create Process, Create Thread, and Image Load callbacks.

        On some systems, the byte sequences will be inaccurate or the exported 
        function will not be found. In these cases, the PoolScanGenericCallback
        scanner will pick up the pool associated with the callbacks.

        Yields (symbol, callback address, None) tuples.
        """

        routines = [
                   # push esi; mov esi, offset _PspLoadImageNotifyRoutine
                   ('PsSetLoadImageNotifyRoutine', "\x56\xbe"),
                   # push esi; mov esi, offset _PspCreateThreadNotifyRoutine
                   ('PsSetCreateThreadNotifyRoutine', "\x56\xbe"),
                   # mov edi, offset _PspCreateProcessNotifyRoutine
                   ('PsSetCreateProcessNotifyRoutine', "\xbf"),
                   ]

        for symbol, hexbytes in routines:

            # Find the global variable referenced by the exported symbol.
            # 'is None' is intended here: only the helper's own "not found"
            # result is skipped.
            p = Callbacks._find_operand_pointer(nt_mod, symbol, hexbytes)
            if p is None:
                continue

            # The list is an array of 8 _EX_FAST_REF objects
            addrs = obj.Object('Array', count = 8, targetType = '_EX_FAST_REF',
                    offset = p, vm = nt_mod.obj_vm)

            for addr in addrs:
                callback = addr.dereference_as("_GENERIC_CALLBACK")
                if callback:
                    yield symbol, callback.Callback, None

    def get_fs_callbacks(self):
        """Enumerate the File System change callbacks"""

        for offset in PoolScanFSCallback().scan(self.phys_space):
            callback = obj.Object('_NOTIFICATION_PACKET', offset, self.phys_space)
            yield "IoRegisterFsRegistrationChange", callback.NotificationRoutine, None

    def get_shutdown_callbacks(self):
        """
        Enumerate shutdown notification callbacks.

        Yields the IRP_MJ_SHUTDOWN handler of the driver that owns each 
        sane shutdown packet, with the driver name as details.
        """

        # Index of the shutdown handler in a driver's MajorFunction table
        shutdown_index = devicetree.MAJOR_FUNCTIONS.index('IRP_MJ_SHUTDOWN')

        for offset in PoolScanShutdownCallback().scan(self.phys_space):

            # Instantiate the object in physical space but give it a native
            # VM of kernel space 
            callback = obj.Object('_SHUTDOWN_PACKET',
                            offset = offset,
                            vm = self.phys_space,
                            native_vm = self.kern_space)

            if not callback.sanity_check():
                continue

            # Get the callback's driver object. We've already
            # checked the sanity of the device object pointer. 
            driver_obj = callback.DeviceObject.dereference().DriverObject

            address = driver_obj.MajorFunction[shutdown_index]
            details = str(driver_obj.DriverName)

            yield "IoRegisterShutdownNotification", address, details

    def get_bugcheck_callbacks(self):
        """
        Enumerate generic Bugcheck callbacks.

        Note: These structures don't exist in tagged pools, but you can find 
        them via KDDEBUGGER_DATA64 on all versions of Windows.
        """

        kbcclh = tasks.get_kdbg(self.kern_space).KeBugCheckCallbackListHead.dereference_as('_KBUGCHECK_CALLBACK_RECORD')

        for l in kbcclh.Entry.list_of_type("_KBUGCHECK_CALLBACK_RECORD", "Entry"):
            yield "KeBugCheckCallbackListHead", l.CallbackRoutine, l.Component.dereference()

    @staticmethod
    def get_registry_callbacks_legacy(nt_mod):
        """
        Enumerate registry change callbacks.

        This method of finding a global variable via disassembly of the 
        CmRegisterCallback function is only for XP systems. If it fails on 
        XP you can still find the callbacks using PoolScanGenericCallback. 

        On Vista and Windows 7, these callbacks are registered using the 
        CmRegisterCallbackEx function. 

        Yields nothing when distorm3 is not installed.
        """

        if not has_distorm3:
            return

        symbol = "CmRegisterCallback"

        # Get the RVA of the symbol from NT's EAT
        symbol_rva = nt_mod.getprocaddress(symbol)
        if symbol_rva == None:
            return

        # Absolute VA to the symbol code 
        symbol_address = symbol_rva + nt_mod.DllBase

        # Read the function prologue 
        data = nt_mod.obj_vm.zread(symbol_address, 200)

        c = 0
        vector = None

        # Looking for MOV EBX, CmpCallBackVector
        # This may be the first or second MOV EBX instruction, so the 
        # second match (when there is one) overrides the first
        for op in distorm3.Decompose(symbol_address, data, distorm3.Decode32Bits):
            if op.valid and op.mnemonic == "MOV" and len(op.operands) == 2 and op.operands[0].name == 'EBX':
                vector = op.operands[1].value
                if c == 1:
                    break
                else:
                    c += 1

        # Can't find the global variable 
        if vector == None:
            return

        # The vector is an array of 100 _EX_FAST_REF objects
        addrs = obj.Object("Array", count = 100, offset = vector,
                    vm = nt_mod.obj_vm, targetType = "_EX_FAST_REF")

        for addr in addrs:
            callback = addr.dereference_as("_EX_CALLBACK_ROUTINE_BLOCK")
            if callback:
                yield symbol, callback.Function, None

    def get_generic_callbacks(self):
        """
        Enumerate generic callbacks of the following types:

        * PsSetCreateProcessNotifyRoutine
        * PsSetCreateThreadNotifyRoutine
        * PsSetLoadImageNotifyRoutine
        * CmRegisterCallback (on XP only)
        * DbgkLkmdRegisterCallback (on Windows 7 only)

        The only issue is that you can't distinguish between the types by just 
        finding the generic callback structure 
        """

        for offset in PoolScanGenericCallback().scan(self.phys_space):
            callback = obj.Object('_GENERIC_CALLBACK', offset, self.phys_space)
            yield "GenericKernelCallback", callback.Callback, None

    def get_dbgprint_callbacks(self):
        """Enumerate DebugPrint callbacks on Vista and 7"""

        for offset in PoolScanDbgPrintCallback().scan(self.phys_space):
            callback = obj.Object('_DBGPRINT_CALLBACK', offset, self.phys_space)
            yield "DbgSetDebugPrintCallback", callback.Function, None

    def get_registry_callbacks(self):
        """
        Enumerate registry callbacks on Vista and 7.

        These callbacks are installed via CmRegisterCallback
        or CmRegisterCallbackEx.
        """

        for offset in PoolScanRegistryCallback().scan(self.phys_space):
            callback = obj.Object('_REGISTRY_CALLBACK', offset, self.phys_space)
            yield "CmRegisterCallback", callback.Function, None

    def get_pnp_callbacks(self):
        """
        Enumerate IoRegisterPlugPlayNotification callbacks.

        Yields (event category, callback routine, driver name) for each 
        hit of the Pnp9, PnpD and PnpC scanners, in that order.
        """

        for scanner in (PoolScanPnp9, PoolScanPnpD, PoolScanPnpC):
            for offset in scanner().scan(self.phys_space):
                entry = obj.Object("_NOTIFY_ENTRY_HEADER", offset = offset,
                            vm = self.phys_space, native_vm = self.kern_space)

                # Dereference the driver object pointer
                driver = entry.DriverObject.dereference()

                # Instantiate an object header for the driver name 
                header = obj.Object("_OBJECT_HEADER", offset = driver.obj_offset -
                    driver.obj_vm.profile.get_obj_offset("_OBJECT_HEADER", "Body"),
                    vm = driver.obj_vm,
                    native_vm = driver.obj_native_vm)

                # Grab the object name 
                driver_name = header.NameInfo.Name.v()

                yield entry.EventCategory, entry.CallbackRoutine, driver_name

    @staticmethod
    def get_bugcheck_reason_callbacks(nt_mod):
        """
        Enumerate Bugcheck Reason callbacks.

        Note: These structures don't exist in tagged pools, so we 
        find them by locating the list head which is a non-exported 
        NT symbol. The method works on all x86 versions of Windows. 

        mov [eax+KBUGCHECK_REASON_CALLBACK_RECORD.Entry.Blink], \
                offset _KeBugCheckReasonCallbackListHead
        """

        symbol = "KeRegisterBugCheckReasonCallback"
        hexbytes = "\xC7\x40\x04"

        # Locate the export and search its code for the pattern. 'is None' 
        # is intended: only the helper's own "not found" result stops here.
        p = Callbacks._find_operand_pointer(nt_mod, symbol, hexbytes)
        if p is None:
            return

        bugs = p.dereference_as('_KBUGCHECK_REASON_CALLBACK_RECORD')

        for l in bugs.Entry.list_of_type("_KBUGCHECK_REASON_CALLBACK_RECORD", "Entry"):
            yield symbol, l.CallbackRoutine, l.Component.dereference()

    def calculate(self):
        """
        Collect callbacks from every source valid for the profile.

        Yields ((type, callback, details), mods, mod_addrs) tuples, where 
        mods maps masked module base addresses to modules and mod_addrs 
        is the sorted list of those bases.
        """

        # All scanners will share a kernel and physical space 
        self.kern_space = utils.load_as(self._config)
        self.phys_space = utils.load_as(self._config, astype = 'physical')

        # We currently don't support x64
        if not self.is_valid_profile(self.kern_space.profile):
            debug.error("This command does not support the selected profile.")

        # Get the OS version we're analyzing
        version = (self.kern_space.profile.metadata.get('major', 0),
                   self.kern_space.profile.metadata.get('minor', 0))

        modlist = list(modules.lsmod(self.kern_space))
        mods = dict((self.kern_space.address_mask(mod.DllBase), mod) for mod in modlist)
        mod_addrs = sorted(mods.keys())

        # The first loaded module is used as the NT module. The module 
        # list can be empty, in which case the export-based routines are 
        # skipped instead of failing with an IndexError.
        nt_mod = modlist[0] if modlist else None
        if nt_mod is None:
            debug.warning("No loaded modules found; skipping callbacks located via NT exports")

        # First few routines are valid on all OS versions. The generators 
        # are lazy, so nothing is scanned until they are iterated below.
        sources = [
            self.get_fs_callbacks(),
            self.get_bugcheck_callbacks(),
            self.get_shutdown_callbacks(),
            self.get_generic_callbacks(),
        ]

        if nt_mod is not None:
            sources.append(self.get_bugcheck_reason_callbacks(nt_mod))
            sources.append(self.get_kernel_callbacks(nt_mod))

        # Valid for Vista and later
        if version >= (6, 0):
            sources.append(self.get_dbgprint_callbacks())
            sources.append(self.get_registry_callbacks())
            sources.append(self.get_pnp_callbacks())

        # Valid for XP 
        if version == (5, 1) and nt_mod is not None:
            sources.append(self.get_registry_callbacks_legacy(nt_mod))

        for source in sources:
            for info in source:
                yield info, mods, mod_addrs

    def render_text(self, outfd, data):
        """Write one table row per callback produced by calculate() to outfd"""

        self.table_header(outfd,
                        [("Type", "36"),
                         ("Callback", "[addrpad]"),
                         ("Module", "20"),
                         ("Details", ""),
                        ])

        for (sym, cb, detail), mods, mod_addrs in data:

            module = tasks.find_module(mods, mod_addrs, self.kern_space.address_mask(cb))

            ## The original callbacks plugin searched driver objects
            ## if the owning module isn't found (Rustock.B). We leave that 
            ## task up to the user this time, and will be incorporating 
            ## some different module association methods later. 
            if module:
                module_name = module.BaseDllName or module.FullDllName
            else:
                module_name = "UNKNOWN"

            self.table_row(outfd, sym, cb, module_name, detail or "-")
